import time
from queue import Empty, Queue
from threading import Thread

import numpy as np
import pygame

from mini_bdx_runtime.buttons import Buttons


# Stick axes whose magnitude is at or below this value, and normalized trigger
# values below it, are reported as 0.
AXIS_DEAD_ZONE = 0.1

# [min, max] bounds used to scale normalized stick deflection into locomotion
# commands (units are not stated in this file -- TODO confirm).
X_RANGE = [-0.15, 0.15]  # forward/backward linear velocity (lin_vel_x)
Y_RANGE = [-0.2, 0.2]  # sideways linear velocity (lin_vel_y)
YAW_RANGE = [-1.0, 1.0]  # angular velocity (ang_vel)

# Head joint bounds as [min, max], in radians.
# NECK_PITCH_RANGE is not used by XBoxController (neck pitch is held at 0).
NECK_PITCH_RANGE = [-0.34, 1.1]
HEAD_PITCH_RANGE = [-0.78, 0.3]
HEAD_YAW_RANGE = [-0.5, 0.5]
HEAD_ROLL_RANGE = [-0.5, 0.5]


class XBoxController:
    """Poll a pygame joystick in a background thread and expose its commands.

    A daemon thread samples joystick 0, sleeping ``1 / command_freq`` seconds
    between samples, and publishes each sample through a one-slot queue.
    ``get_last_command`` is the non-blocking consumer meant to be called from
    the main thread.

    The command vector has seven entries:
    ``[lin_vel_x, lin_vel_y, ang_vel, neck_pitch, head_pitch, head_yaw,
    head_roll]``. In locomotion mode the sticks drive the first three entries;
    in head-control mode they drive the head entries and the first four are
    zeroed. Pressing Y switches between the two modes unless the controller
    was created with ``only_head_control=True``.
    """

    # pygame button index -> name of the instance flag tracking that button.
    _BUTTON_FLAGS = {
        0: "A_pressed",
        1: "B_pressed",
        3: "X_pressed",
        2: "Y_pressed",
        4: "LB_pressed",
        5: "RB_pressed",
    }
    _Y_BUTTON = 2

    def __init__(self, command_freq, only_head_control=False):
        """Connect to joystick 0 (retrying until found) and start polling.

        Args:
            command_freq: Polling rate in Hz; the worker sleeps
                ``1 / command_freq`` seconds between samples.
            only_head_control: Start in head-control mode and ignore the Y
                button's mode toggle.
        """
        self.command_freq = command_freq
        self.head_control_mode = only_head_control
        self.only_head_control = only_head_control

        self.last_commands = [0.0] * 7
        self.last_left_trigger = 0.0
        self.last_right_trigger = 0.0

        # Block until a joystick can be opened, retrying once per second.
        while True:
            try:
                print("Trying to connect to the joystick...")
                pygame.init()
                self.p1 = pygame.joystick.Joystick(0)
                self.p1.init()
                print(f"Loaded joystick with {self.p1.get_numaxes()} axes.")
                break
            except pygame.error:
                pygame.quit()
                self.p1 = None
                print("Warning: No joystick found.")
                time.sleep(1)

        # One slot only: the worker blocks in put() until the sample is read.
        self.cmd_queue = Queue(maxsize=1)

        self.A_pressed = False
        self.B_pressed = False
        self.X_pressed = False
        self.Y_pressed = False
        self.LB_pressed = False
        self.RB_pressed = False

        self.buttons = Buttons()

        Thread(target=self.commands_worker, daemon=True).start()

    def commands_worker(self):
        """Background loop: sample the joystick and publish each sample."""
        while True:
            self.cmd_queue.put(self.get_commands())
            time.sleep(1 / self.command_freq)

    def _read_stick_axis(self, axis):
        """Return stick axis ``axis``, or 0 when it is inside the dead zone."""
        value = self.p1.get_axis(axis)
        return value if abs(value) > AXIS_DEAD_ZONE else 0

    def _read_trigger(self, axis):
        """Return trigger ``axis`` mapped through ``(raw + 1) / 2``.

        The result is rounded to 3 decimals; values below the dead zone
        become 0.
        """
        value = np.around((self.p1.get_axis(axis) + 1) / 2, 3)
        return 0 if value < AXIS_DEAD_ZONE else value

    @staticmethod
    def _scale(value, non_negative_limit, negative_limit):
        """Scale ``value`` by the magnitude of the limit matching its sign."""
        limit = non_negative_limit if value >= 0 else negative_limit
        return value * np.abs(limit)

    def _process_button_events(self):
        """Update the ``*_pressed`` flags from queued joystick button events.

        Each event is decoded through ``event.button`` so that only the button
        that actually changed is affected: pressing another button while Y is
        held no longer re-toggles head-control mode, and releasing one button
        no longer clears the flags of buttons that are still held.
        """
        for event in pygame.event.get():
            if event.type not in (pygame.JOYBUTTONDOWN, pygame.JOYBUTTONUP):
                continue

            flag = self._BUTTON_FLAGS.get(event.button)
            if flag is None:
                continue  # a button this class does not track

            pressed = event.type == pygame.JOYBUTTONDOWN
            setattr(self, flag, pressed)

            if (
                pressed
                and event.button == self._Y_BUTTON
                and not self.only_head_control
            ):
                self.head_control_mode = not self.head_control_mode

    # Running in background thread!!!
    def get_commands(self):
        """Sample the joystick once and build the command tuple.

        Returns:
            ``(commands, A, B, X, Y, LB, RB, left_trigger, right_trigger,
            up_down)`` where ``commands`` is the 7-entry command vector
            rounded to 3 decimals, the six booleans are the tracked button
            flags, the triggers are the processed trigger values, and
            ``up_down`` is the vertical component of hat 0 (0 if the
            joystick reports no hat).
        """
        # Work on a copy: self.last_commands is the very object that
        # get_last_command() already returned to the main-thread caller, so
        # mutating it here would change data behind the caller's back.
        commands = list(self.last_commands)

        # Axis 4 (right stick Y) is not used by either mode.
        l_x = self._read_stick_axis(0)
        l_y = self._read_stick_axis(1)
        left_trigger = self._read_trigger(2)
        r_x = self._read_stick_axis(3)
        right_trigger = self._read_trigger(5)

        if not self.head_control_mode:
            # Stick axes are negated before scaling in locomotion mode.
            commands[0] = self._scale(-l_y, X_RANGE[1], X_RANGE[0])
            commands[1] = self._scale(-l_x, Y_RANGE[1], Y_RANGE[0])
            commands[2] = self._scale(-r_x, YAW_RANGE[1], YAW_RANGE[0])
        else:
            # TODO(richard): head control mode has not been tested with a PS4
            # joystick.
            commands[0] = 0.0
            commands[1] = 0.0
            commands[2] = 0.0
            commands[3] = 0.0  # neck pitch 0 for now

            # NOTE(review): here a non-negative deflection is scaled by
            # |RANGE[0]| and a negative one by |RANGE[1]| -- the opposite
            # pairing from locomotion mode -- so each output spans
            # [-|RANGE[1]|, +|RANGE[0]|]. Kept as-is; confirm this matches the
            # intended sign convention of the head joints.
            head_yaw = self._scale(l_x, HEAD_YAW_RANGE[0], HEAD_YAW_RANGE[1])
            head_pitch = self._scale(
                l_y, HEAD_PITCH_RANGE[0], HEAD_PITCH_RANGE[1]
            )
            head_roll = self._scale(r_x, HEAD_ROLL_RANGE[0], HEAD_ROLL_RANGE[1])

            commands[4] = head_pitch
            commands[5] = head_yaw
            commands[6] = head_roll

        self._process_button_events()

        # get_hat() raises pygame.error when the joystick has no hat, which
        # would silently kill this daemon thread; report a centred D-pad
        # instead.
        up_down = self.p1.get_hat(0)[1] if self.p1.get_numhats() > 0 else 0
        pygame.event.pump()  # process event queue

        return (
            np.around(commands, 3),
            self.A_pressed,
            self.B_pressed,
            self.X_pressed,
            self.Y_pressed,
            self.LB_pressed,
            self.RB_pressed,
            left_trigger,
            right_trigger,
            up_down,
        )

    # Calling in main thread
    def get_last_command(self):
        """Return the freshest joystick sample without blocking.

        If the worker has published a new sample it is consumed and stored;
        otherwise the previously stored commands and trigger values are
        returned again. ``self.buttons`` is updated on every call.

        Returns:
            ``(last_commands, buttons, last_left_trigger,
            last_right_trigger)``.
        """
        # Defaults used when no fresh sample is available: every button is
        # reported to ``self.buttons`` as released and the D-pad as centred.
        A_pressed = False
        B_pressed = False
        X_pressed = False
        Y_pressed = False
        LB_pressed = False
        RB_pressed = False
        up_down = 0

        try:
            sample = self.cmd_queue.get_nowait()
        except Empty:
            pass  # nothing new since the last call; keep the stored values
        else:
            (
                self.last_commands,
                A_pressed,
                B_pressed,
                X_pressed,
                Y_pressed,
                LB_pressed,
                RB_pressed,
                self.last_left_trigger,
                self.last_right_trigger,
                up_down,
            ) = sample

        self.buttons.update(
            A_pressed,
            B_pressed,
            X_pressed,
            Y_pressed,
            LB_pressed,
            RB_pressed,
            up_down == 1,
            up_down == -1,
        )

        return (
            self.last_commands,
            self.buttons,
            self.last_left_trigger,
            self.last_right_trigger,
        )
    

def float2str(num):
    """Format ``num`` with three decimals and an explicit leading sign.

    The ``+`` format flag is used instead of prepending "+" by hand, so that
    negative zero renders as "-0.000" rather than the malformed "+-0.000".
    """
    return f'{num:+.3f}'

if __name__ == "__main__":
    # Manual check: render the live controller state as a terminal table
    # until ESC or 'q' is pressed.
    from rich.live import Live
    from rich.table import Table
    from mini_bdx_runtime.keyboard_utils import is_any_key_pressed

    controller = XBoxController(20)

    # (header, width) of every column, in display order.
    column_specs = (
        ('', 5),
        ('Forward', 7),
        ('Strafe', 7),
        ('Rotate', 7),
        ('A', 6),
        ('B', 6),
        ('X', 6),
        ('Y', 6),
        ('LB', 6),
        ('RB', 6),
        ('LT', 7),
        ('RT', 7),
    )

    with Live(refresh_per_second=30) as live:
        while True:
            commands, buttons, left_trigger, right_trigger = (
                controller.get_last_command()
            )

            # A fresh single-row table is built for each refresh.
            table = Table(title='PS4 Joystick', show_header=True, show_lines=True)
            for header, width in column_specs:
                table.add_column(header, width=width)

            velocity_cells = [float2str(commands[i]) for i in range(3)]
            button_cells = [
                str(button.is_pressed)
                for button in (
                    buttons.A,
                    buttons.B,
                    buttons.X,
                    buttons.Y,
                    buttons.LB,
                    buttons.RB,
                )
            ]
            trigger_cells = [float2str(left_trigger), float2str(right_trigger)]

            table.add_row('Final', *velocity_cells, *button_cells, *trigger_cells)
            live.update(table)

            if is_any_key_pressed(['\x1b', 'q']):
                break

            time.sleep(0.05)
